import numpy as np
import torch
import scipy
import math
from scipy.interpolate import interp1d
import random

def gen_aug(sample, ssh_type, args=None):
    """
    Apply the augmentation selected by `ssh_type` to a batch of samples.

    Args:
        sample: batch to augment. Most branches hand it to a NumPy-based helper and
            wrap the result with `torch.from_numpy`.
        ssh_type: name of the augmentation (see the branches below).
        args: extra settings; forwarded only by 'resample_2', 'freq_shift' and 'random'.
    Returns:
        The augmented batch, usually as a torch tensor. 'resample_2' and 'freq_shift'
        return a pair `(samples, ratio)` where `samples` is whatever the helper
        produced and `ratio` is a torch tensor. An unknown `ssh_type` prints a
        message and returns None.
    """
    if ssh_type == 'na':
        return torch.from_numpy(sample)
    elif ssh_type == 'shuffle':
        return shuffle(sample)
    elif ssh_type == 'jit_scal':
        scale_sample = scaling(jitter(sample), sigma=2)
        return torch.from_numpy(scale_sample)
    elif ssh_type == 'perm_jit':
        return torch.from_numpy(jitter(permutation(sample, max_segments=5), sigma=0.5))
    elif ssh_type == 'resample':
        return torch.from_numpy(resample(sample))
    elif ssh_type == 'resample_2':
        lin_samples, lin_ratio = linear_resample(sample, args)
        return lin_samples, torch.from_numpy(lin_ratio)
    elif ssh_type == 'freq_shift':
        shifted_samples, lin_ratio = freq_shift(sample, args)
        return shifted_samples, torch.from_numpy(lin_ratio)
    elif ssh_type == 'noise':
        return torch.from_numpy(jitter(sample))
    elif ssh_type == 'scale':
        return torch.from_numpy(scaling(sample))
    elif ssh_type == 'negate':
        return torch.from_numpy(negated(sample))
    elif ssh_type == 'shift':
        return torch.from_numpy(shift_random(sample))
    elif ssh_type == 't_flip':
        # np.flip returns a view with negative strides, which torch.from_numpy rejects.
        return torch.from_numpy(time_flipped(sample).copy())
    elif ssh_type == 'rotation':
        # Rotate exactly once: every call draws new random rotations, so the array
        # that is type-checked must be the very array that is returned.
        rotated = multi_rotation(sample)
        if isinstance(rotated, np.ndarray):
            return torch.from_numpy(rotated)
        return rotated
    elif ssh_type == 'perm':
        return torch.from_numpy(permutation(sample, max_segments=5))
    elif ssh_type == 't_warp':
        return torch.from_numpy(time_warp(sample))
    elif ssh_type == 'random_out':
        return torch.from_numpy(aug_random_zero_out(sample))
    elif ssh_type == 'random':
        return torch.from_numpy(_apply_augmentation(sample, args))
    else:
        print('The task is not available!\n')
        return None


def aug_random_zero_out(x, max_len=0):
    """
    Zero out one random contiguous stretch of every sample along axis 1.

    Args:
        x: array of shape (N, L, ...); the stretch is chosen along the axis of length L.
        max_len: longest stretch (in steps) that may be zeroed. Values <= 0 fall back
            to L // 10. The effective maximum is never smaller than 1.
    Returns:
        A copy of `x` with the chosen stretch set to 0 across all trailing axes.
        The input is left untouched.
    """
    N, L = x.shape[0], x.shape[1]
    out = x.copy()
    if L == 0:
        return out

    # Honour the caller's limit; only fall back to a tenth of the length when unset.
    longest = int(max_len) if max_len and max_len > 0 else L // 10
    longest = max(longest, 1)  # keeps randint's range non-empty for short series

    for i in range(N):
        start = np.random.randint(0, L)
        length = np.random.randint(1, longest + 1)
        end = min(start + length, L)
        # L was read from axis 1, so the stretch is cut along axis 1 as well.
        out[i, start:end] = 0

    return out


def shuffle(x):
    """
    Reorder axis 1 of every item in `x` with a fixed (seed 21) permutation.

    Args:
        x: iterable of 2-D items, given either as NumPy arrays or torch tensors.
    Returns:
        torch tensor stacking the reordered items along a new leading axis.
    """
    sample_ssh = []
    for data in x:
        # torch.stack only accepts tensors, so NumPy input is converted up front.
        data = torch.as_tensor(data)
        # A fresh RandomState with a constant seed gives the same permutation every call.
        p = np.random.RandomState(seed=21).permutation(data.shape[1])
        sample_ssh.append(data[:, torch.from_numpy(p).long()])
    return torch.stack(sample_ssh)

def shift_random(x):
    """
    Circularly roll every item of `x` along its first axis by a random offset.

    The offset of each item is drawn independently from [0, item.shape[0]).
    Returns the rolled items stacked into one NumPy array.
    """
    rolled = [
        np.roll(item, np.random.randint(0, item.shape[0]), axis=0)
        for item in x
    ]
    return np.stack(rolled)

def jitter(x, sigma=0.3):
    """
    Add zero-mean Gaussian noise with standard deviation `sigma` to every element of `x`.

    Reference: https://arxiv.org/pdf/1706.00527.pdf
    """
    noise = np.random.normal(loc=0., scale=sigma, size=x.shape)
    return x + noise


def scaling(x, sigma=1.1):
    """
    Multiply `x` by random factors that are shared by all slices along the last axis.

    One factor is drawn from a normal distribution (mean 2.0, std `sigma`) for every
    (axis 0, axis 1) position, and the same factor grid is applied to each index of
    axis 2, so every slice along the last axis receives the same distortion.

    Reference: https://arxiv.org/pdf/1706.00527.pdf
    """
    factor = np.random.normal(loc=2., scale=sigma, size=(x.shape[0], x.shape[1]))
    scaled = [np.multiply(x[:, :, ch], factor) for ch in range(x.shape[2])]
    return np.stack(scaled, axis=2)


def negated(X):
    """Return `X` with the sign of every value inverted."""
    sign = -1
    return X * sign

def time_flipped(X):
    """
    Return `X` reversed along its last axis.

    The result is a NumPy view with negative strides; copy it before handing it to
    code that needs contiguous memory (e.g. `torch.from_numpy`).

    NOTE(review): the flip is on the last axis. If inputs are laid out as
    (batch, time, channel), that reverses channel order rather than time - confirm
    the intended layout with the callers.
    """
    return np.flip(X, -1)

def soft_time_flipped(X, n_flip=3):
    """
    Reverse a random subset of the last-axis slices of `X` along axis 1.

    Args:
        X: 3-D torch tensor. It is modified in place.
        n_flip: how many slices of axis 2 to reverse (capped at X.shape[2]).
    Returns:
        The same tensor `X`, after the in-place update.
    """
    n_channels = X.shape[2]
    # Draw from the tensor's real width so narrow inputs do not index out of range
    # and wide inputs can have any of their slices selected.
    reverse_channels = torch.randperm(n_channels)[:min(n_flip, n_channels)]
    X[:, :, reverse_channels] = torch.flip(X[:, :, reverse_channels], dims=[1])
    return X

def permutation(x, max_segments=5, seg_mode="random"):
    """
    Cut every sample into segments along axis 1 and put the segments back in random order.

    Args:
        x: array whose axis 0 indexes samples and axis 1 indexes steps.
        max_segments: exclusive upper bound on the number of segments per sample
            (the count is drawn from [1, max_segments)).
        seg_mode: "random" draws the cut positions at random; any other value
            splits the steps into near-equal parts.
    Returns:
        Array shaped like `x`. Samples that drew a single segment are copied unchanged.
    """
    n_steps = x.shape[1]
    step_index = np.arange(n_steps)
    seg_counts = np.random.randint(1, max_segments, size=(x.shape[0]))
    shuffled = np.zeros_like(x)
    for row, series in enumerate(x):
        n_seg = seg_counts[row]
        if n_seg <= 1:
            # A single segment leaves nothing to reorder.
            shuffled[row] = series
            continue
        if seg_mode == "random":
            cuts = np.random.choice(n_steps - 2, n_seg - 1, replace=False)
            cuts.sort()
            chunks = np.split(step_index, cuts)
        else:
            chunks = np.array_split(step_index, n_seg)
        np.random.shuffle(chunks)
        order = np.concatenate(chunks).ravel()
        shuffled[row] = series[order]
    return shuffled

def resample(x):
    """
    Resample `x` along axis 1 at 2/3 of a step per output sample.

    The series is linearly interpolated onto a grid three times finer than the
    original one, then every second fine-grid point is taken, starting from a
    random fine-grid index, until the original length is reached.
    """
    n_steps = x.shape[1]
    orig_steps = np.arange(n_steps)
    fine_steps = np.arange(0, orig_steps[-1] + 0.001, 1 / 3)
    upsampled = interp1d(orig_steps, x, axis=1)(fine_steps)
    start = random.choice(orig_steps)
    picked = np.arange(start, 3 * n_steps, 2)[:n_steps]
    return upsampled[:, picked, :]

def freq_shift(x, args):
    """
    Shift each row of `x` by a random frequency offset using its analytic signal.

    Args:
        x: 2-D array, one signal per row.
        args: object exposing `fs`, used to build the time axis.
    Returns:
        (shifted, offsets): `shifted` has the shape of `x`; `offsets` has shape
        (x.shape[0], 1) and holds the offset drawn for each row (mean -0.05, std 0.1).
    """
    from scipy.signal import hilbert
    n_signals, n_steps = x.shape[0], x.shape[1]
    t = np.linspace(0, n_steps / args.fs, n_steps)
    w_0 = np.random.normal(-0.05, 0.1, size=n_signals)  # In Hz
    lin_ratio = np.zeros((n_signals, 1))
    lin_samples = np.zeros((n_signals, n_steps))
    for row in range(n_signals):
        shift_hz = w_0[row]
        # Multiplying the analytic signal by a complex exponential moves its spectrum.
        carrier = np.exp(2j * np.pi * shift_hz * t)
        lin_samples[row, :] = np.real(hilbert(x[row]) * carrier)
        lin_ratio[row, :] = shift_hz
    return lin_samples, lin_ratio

def linear_resample(x, args):
    """
    Truncate each row of `x` to a random leading part and stretch it back to full length.

    Args:
        x: 2-D array, one signal per row.
        args: object exposing `fs`, used to convert between samples and seconds.
    Returns:
        (stretched, ratio): `stretched` has the shape of `x`; `ratio` has shape
        (x.shape[0], 1) and holds kept_length / original_length for each row. The
        kept duration is drawn uniformly between 2/3 of and the full original duration.
    """
    from scipy.interpolate import interp1d
    n_signals, n_steps = x.shape[0], x.shape[1]
    original_duration = n_steps / args.fs
    cut_window = np.random.uniform(low=2*original_duration/3, high=original_duration, size=n_signals)
    lin_ratio, lin_samples = np.zeros((n_signals, 1)), np.zeros((n_signals, n_steps))
    for i in range(n_signals):
        x_shorter = x[i, :int(cut_window[i] * args.fs)]
        kept = x_shorter.shape[0]
        lin_ratio[i] = kept / n_steps
        # Build the time axis from an integer range: a float-step arange can return
        # one element too many, which makes interp1d fail on a length mismatch.
        orig_steps = np.arange(kept) / args.fs
        Interp = interp1d(orig_steps, x_shorter)

        interp_steps = np.linspace(0, orig_steps[-1], n_steps)
        lin_samples[i, :] = Interp(interp_steps)
    return lin_samples, lin_ratio

def multi_rotation(x):
    """
    Rotate every consecutive triple of last-axis slices of `x` independently.

    Axis 2 is processed in groups of three; trailing slices that do not fill a
    complete group are dropped. When there is no complete group an empty 1-D
    array is returned.
    """
    combined = np.array([])
    n_groups = x.shape[2] // 3
    for first in range(0, 3 * n_groups, 3):
        rotated = rotation(x[:, :, first:first + 3])
        if combined.size:
            combined = np.concatenate((combined, rotated), axis=2)
        else:
            combined = rotated
    return combined

def rotation(X):
    """
    Apply an independent random 3D rotation to every sample of `X`.

    One rotation axis (components uniform in [-1, 1)) and one angle (uniform in
    [-pi, pi)) are drawn per sample; each sample is then right-multiplied by the
    resulting rotation matrix.
    """
    n_samples, n_components = X.shape[0], X.shape[2]
    axes = np.random.uniform(low=-1, high=1, size=(n_samples, n_components))
    angles = np.random.uniform(low=-np.pi, high=np.pi, size=(n_samples))
    return np.matmul(X, axis_angle_to_rotation_matrix_3d_vectorized(axes, angles))

def axis_angle_to_rotation_matrix_3d_vectorized(axes, angles):
    """
    Build one 3x3 rotation matrix per (axis, angle) pair.

    Args:
        axes: array of shape (N, 3); every row is normalised to unit length here.
        angles: array of N rotation angles in radians.
    Returns:
        Array of shape (N, 3, 3) holding the rotation matrices.

    Reference: the Transforms3d package - transforms3d.axangles.axangle2mat
    Formula: http://en.wikipedia.org/wiki/Rotation_matrix#Axis_and_angle
    """
    unit = axes / np.linalg.norm(axes, ord=2, axis=1, keepdims=True)
    ux, uy, uz = unit[:, 0], unit[:, 1], unit[:, 2]
    cos_a = np.cos(angles)
    sin_a = np.sin(angles)
    versine = 1 - cos_a

    ux_s, uy_s, uz_s = ux * sin_a, uy * sin_a, uz * sin_a
    ux_v, uy_v, uz_v = ux * versine, uy * versine, uz * versine
    xy_v = ux * uy_v
    yz_v = uy * uz_v
    zx_v = uz * ux_v

    row_0 = [ux * ux_v + cos_a, xy_v - uz_s, zx_v + uy_s]
    row_1 = [xy_v + uz_s, uy * uy_v + cos_a, yz_v - ux_s]
    row_2 = [zx_v - uy_s, yz_v + ux_s, uz * uz_v + cos_a]

    # The stacked entries have shape (3, 3, N); bring the sample axis to the front.
    return np.moveaxis(np.array([row_0, row_1, row_2]), 2, 0)

def get_cubic_spline_interpolation(x_eval, x_data, y_data):
    """
    Fit a cubic spline through (x_data, y_data) and return its values at x_eval.
    """
    return scipy.interpolate.CubicSpline(x_data, y_data)(x_eval)


def time_warp(X, sigma=0.2, num_knots=4):
    """
    Stretch and compress every (sample, last-axis index) series of `X` along axis 1.

    For each series a smooth random speed curve is built from a cubic spline through
    `num_knots + 2` knots whose values are drawn around 1.0 with std `sigma`. Its
    cumulative sum, rescaled to end at the last step index, gives the distorted
    time stamps on which the series is linearly re-interpolated.

    Returns a float array with the shape of `X`.
    """
    n_steps = X.shape[1]
    n_channels = X.shape[2]
    time_stamps = np.arange(n_steps)
    knot_xs = np.arange(0, num_knots + 2, dtype=float) * (n_steps - 1) / (num_knots + 1)
    spline_ys = np.random.normal(loc=1.0, scale=sigma, size=(X.shape[0] * n_channels, num_knots + 2))

    speed_curves = np.array([
        get_cubic_spline_interpolation(time_stamps, knot_xs, knot_ys)
        for knot_ys in spline_ys
    ])

    elapsed = np.cumsum(speed_curves, axis=1)
    warped_stamps = elapsed / elapsed[:, -1][:, np.newaxis] * (n_steps - 1)

    X_transformed = np.empty(shape=X.shape)
    for flat_index, stamps in enumerate(warped_stamps):
        # Rows of warped_stamps are ordered sample-major, then by last-axis index.
        sample_idx, channel_idx = divmod(flat_index, n_channels)
        X_transformed[sample_idx, :, channel_idx] = np.interp(
            time_stamps, stamps, X[sample_idx, :, channel_idx])
    return X_transformed


# Mapping from augmentation name to the callable that implements it.
# Every callable takes the sample batch as its only argument; the names are the
# ones listed per dataset in _apply_augmentation and looked up by
# _apply_random_augmentation.
AUGMENTATION_MAP = {
    # Segment permutation (max_segments=5) followed by jitter with sigma=0.5.
    'perm_jit': lambda sample: jitter(permutation(sample, max_segments=5), sigma=0.5),
    'resample': resample,
    'jitter': jitter,
    'scale': scaling,
    'shift': shift_random,
    'rotation': multi_rotation,
    'negate': negated,
}

def _apply_augmentation(x, args):
    """
    Apply one randomly chosen augmentation from the set configured for `args.dataset`.

    Args:
        x: sample batch forwarded to the chosen augmentation.
        args: object exposing `dataset`, the name used to select the augmentation set.
    Returns:
        The augmented batch returned by the chosen augmentation.
    Raises:
        ValueError: if no augmentation set is configured for `args.dataset`.
    """
    if args.dataset in ['ieee_small', 'ieee_big', 'dalia']:
        augmentations = ['perm_jit', 'jitter', 'scale', 'shift'] # with scale is better
    elif args.dataset in ['usc', 'hhar', 'clemson']:
        augmentations = ['perm_jit', 'jitter', 'scale', 'shift', 'rotation']
    elif args.dataset in ['cpsc', 'chapman', 'sleep']:
        augmentations = ['resample', 'jitter','scale','shift', 'negate'] # negate increases the performance
    else:
        # Fail here with a clear message instead of returning None, which would only
        # surface later as an unrelated error in the caller.
        raise ValueError(f"No augmentation set is defined for dataset '{args.dataset}'.")
    return _apply_random_augmentation(x, augmentations)

def _apply_random_augmentation(sample, augmentations):
    """
    Pick one name from `augmentations` at random and apply that augmentation to `sample`.

    Raises:
        ValueError: if the chosen name has no entry in AUGMENTATION_MAP.
    """
    chosen = random.choice(augmentations)
    transform = AUGMENTATION_MAP.get(chosen)
    if transform is not None:
        return transform(sample)
    raise ValueError(f"Augmentation '{chosen}' is not defined.")


####################### TF-C ########################
'''
https://github.com/mims-harvard/TFC-pretraining/blob/main/code/TFC/augmentations.py
'''

def DataTransform_TD(sample):
    """Simply use the jittering augmentation. Feel free to add more augmentations you want,
    but we noticed that in the TF-C framework, the augmentation has little impact on the final transferring performance."""
    gaussian = torch.normal(mean=0.0, std=0.8, size=sample.shape)
    # The noise is drawn on the default device and then moved next to the sample.
    return sample + gaussian.to(sample.device)


def DataTransform_TD_bank(sample, config):
    """Augmentation bank that includes four augmentations and randomly selects one per sample
    as the positive sample. You may use this one to replace the DataTransform_TD function.

    Args:
        sample: batch of samples; one augmentation is selected per entry of axis 0.
        config: object exposing `augmentation.jitter_ratio`,
            `augmentation.jitter_scale_ratio` and `augmentation.max_seg`.
    Returns:
        Batch in which every sample comes from exactly one of: jitter, scaling,
        permutation, masking.
    """
    aug_1 = jitter(sample, config.augmentation.jitter_ratio)
    aug_2 = scaling(sample, config.augmentation.jitter_scale_ratio)
    aug_3 = permutation(sample, max_segments=config.augmentation.max_seg)
    aug_4 = masking(sample, keepratio=0.9)

    li = np.random.randint(0, 4, size=[sample.shape[0]])
    # One-hot with a fixed width of 4, so every column exists even when some
    # augmentation index was not drawn for this batch.
    li_onehot = np.eye(4)[li]
    # Each candidate is gated by its OWN column; rows that did not select it become zero.
    aug_1 = aug_1 * li_onehot[:, 0][:, None, None]
    aug_2 = aug_2 * li_onehot[:, 1][:, None, None]
    aug_3 = aug_3 * li_onehot[:, 2][:, None, None]
    aug_4 = aug_4 * li_onehot[:, 3][:, None, None]
    aug_T = aug_1 + aug_2 + aug_3 + aug_4
    return aug_T

def DataTransform_FD(sample):
    """Frequency-domain augmentation: the sum of a removal-perturbed copy and an
    addition-perturbed copy of `sample`, each using a perturbation ratio of 0.1."""
    removed = remove_frequency(sample, pertub_ratio=0.1)
    added = add_frequency(sample, pertub_ratio=0.1)
    return removed + added

def remove_frequency(x, pertub_ratio=0.0):
    """
    Zero out a random subset of the entries of `x`.

    Args:
        x: torch tensor.
        pertub_ratio: probability with which each entry is set to zero.
    Returns:
        Tensor shaped like `x`, on the same device, with the selected entries zeroed.
    """
    # Draw the mask directly on x's device so the function also works without CUDA.
    mask = torch.rand(x.shape, device=x.device) > pertub_ratio  # masked-out entries are False
    return x * mask

def add_frequency(x, pertub_ratio=0.0):
    """
    Add small random positive amounts to a random subset of the entries of `x`.

    Args:
        x: torch tensor.
        pertub_ratio: probability with which each entry is perturbed.
    Returns:
        Tensor shaped like `x`; each selected entry is increased by a random amount
        of at most 10% of the largest value in `x`.
    """
    # Draw the mask directly on x's device so the function also works without CUDA.
    mask = torch.rand(x.shape, device=x.device) > (1 - pertub_ratio)  # only pertub_ratio of all values are True
    max_amplitude = x.max()
    mask_shape = torch.rand(mask.shape).to(x.device)
    random_am = mask_shape * (max_amplitude * 0.1)
    pertub_matrix = mask * random_am
    return x + pertub_matrix


######################### MTM #########################

def data_transform_masked4cl(sample, masking_ratio, lm, positive_nums=None, distribution='geometric'):
    """
    Build several masked copies of every sample.

    Args:
        sample: 3-D torch tensor; axes 1 and 2 are swapped before masking and
            swapped back afterwards.
        masking_ratio: proportion of values to mask, forwarded to `noise_mask`.
        lm: average masked-run length, forwarded to `noise_mask`.
        positive_nums: number of copies per sample; defaults to
            ceil(1.5 / (1 - masking_ratio)).
        distribution: masking scheme, forwarded to `noise_mask`.
    Returns:
        (masked, mask): both in the axis order of `sample`, with axis 0 repeated
        `positive_nums` times.
    """
    if positive_nums is None:
        positive_nums = math.ceil(1.5 / (1 - masking_ratio))

    swapped = sample.permute(0, 2, 1)
    repeated = swapped.repeat(positive_nums, 1, 1)

    mask = noise_mask(repeated, masking_ratio, lm, distribution=distribution)
    masked = mask.to(repeated.device) * repeated

    return masked.permute(0, 2, 1), mask.permute(0, 2, 1)


def geom_noise_mask_single(L, lm, masking_ratio):
    """
    Randomly create a boolean mask of length `L` made of alternating kept and masked runs.

    The mask is generated by a two-state Markov chain, so the lengths of the masked
    runs (0s) and of the kept runs (1s) follow geometric distributions; masked runs
    have an average length of `lm` and cover a `masking_ratio` proportion of the mask.
    Args:
        L: length of mask and sequence to be masked
        lm: average length of masking subsequences (streaks of 0s)
        masking_ratio: proportion of L to be masked
    Returns:
        (L,) boolean numpy array intended to mask ('drop') with 0s a sequence of length L
    """
    keep_mask = np.ones(L, dtype=bool)
    stop_masked = 1 / lm  # chance that a masked run ends at any given step
    stop_kept = stop_masked * masking_ratio / (1 - masking_ratio)  # same for a kept run
    switch_prob = (stop_masked, stop_kept)  # indexed by state: 0 = masking, 1 = keeping

    # Begin in the masking state with probability masking_ratio.
    state = int(np.random.rand() > masking_ratio)
    for pos in range(L):
        keep_mask[pos] = state  # the state value doubles as the mask value
        state = 1 - state if np.random.rand() < switch_prob[state] else state

    return keep_mask


def noise_mask(X, masking_ratio=0.25, lm=3, distribution='geometric', exclude_feats=None):
    """
    Creates a random boolean mask of the same shape as X, with 0s at places that should be masked.
    Args:
        X: 3-dimensional array or tensor; only its shape is used
        masking_ratio: proportion of values to be masked
        lm: average length of masking subsequences (streaks of 0s). Used only when `distribution` is 'geometric'.
        distribution: selects the masking scheme:
            'geometric': one Markov chain (stateful) is run over all positions in flattened order and
                reshaped to X's shape, giving geometrically distributed masked runs of mean length `lm`
            'masked_tail': along the last axis, the first ceil(len * (1 - masking_ratio)) positions are kept
            'masked_head': along the last axis, the first ceil(len * masking_ratio) positions are masked
            anything else: each position is kept independently with probability 1 - masking_ratio
        exclude_feats: iterable of indices; converted to a set but not otherwise used by this function
    Returns:
        boolean torch tensor with the same shape as X, with 0s at places that should be masked
    """
    if exclude_feats is not None:
        exclude_feats = set(exclude_feats)

    shape = X.shape
    if distribution == 'geometric':  # stateful (Markov chain)
        flat = geom_noise_mask_single(shape[0] * shape[1] * shape[2], lm, masking_ratio)
        mask = flat.reshape(shape[0], shape[1], shape[2])
    elif distribution == 'masked_tail':
        n_keep = math.ceil(shape[2] * (1 - masking_ratio))
        mask = np.zeros(shape, dtype=bool)
        mask[:, :, :n_keep] = True  # keep the head, mask the tail
    elif distribution == 'masked_head':
        n_drop = math.ceil(shape[2] * masking_ratio)
        mask = np.zeros(shape, dtype=bool)
        mask[:, :, n_drop:] = True  # mask the head, keep the tail
    else:  # each position is independent Bernoulli with p = 1 - masking_ratio
        keep_probs = (1 - masking_ratio, masking_ratio)
        mask = np.random.choice(np.array([True, False]), size=shape, replace=True, p=keep_probs)
    return torch.tensor(mask)